//===- GetExtraBuffers.cpp - HIVM get extra buffer implementation ---------===//
//
// Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
//===----------------------------------------------------------------------===//

#include "bishengir/Dialect/HIVM/IR/HIVM.h"

#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/STLExtras.h"
#include <algorithm>

using namespace mlir;
using namespace mlir::hivm;

//===----------------------------------------------------------------------===//
// Macros to help generate `getExtraBuffers`
//===----------------------------------------------------------------------===//

// Emits `OP_NAME::getExtraBuffers()` as a thin forwarder: the returned
// `OperandRange` is whatever the op's own `getTempBufferMutable()` accessor
// yields. The macro is local to this section and is removed again below.
#define DEFINE_GET_EXTRA_BUFFERS_FROM_TEMP_BUFFER(OP_NAME)                     \
  OperandRange OP_NAME::getExtraBuffers() {                                    \
    return getTempBufferMutable();                                             \
  }

// Vector binary ops
DEFINE_GET_EXTRA_BUFFERS_FROM_TEMP_BUFFER(VMulOp)
DEFINE_GET_EXTRA_BUFFERS_FROM_TEMP_BUFFER(VAddOp)
DEFINE_GET_EXTRA_BUFFERS_FROM_TEMP_BUFFER(VMaxOp)
DEFINE_GET_EXTRA_BUFFERS_FROM_TEMP_BUFFER(VMinOp)
DEFINE_GET_EXTRA_BUFFERS_FROM_TEMP_BUFFER(VAndOp)
DEFINE_GET_EXTRA_BUFFERS_FROM_TEMP_BUFFER(VOrOp)
DEFINE_GET_EXTRA_BUFFERS_FROM_TEMP_BUFFER(VSubOp)
DEFINE_GET_EXTRA_BUFFERS_FROM_TEMP_BUFFER(VDivOp)
DEFINE_GET_EXTRA_BUFFERS_FROM_TEMP_BUFFER(VShLOp)
DEFINE_GET_EXTRA_BUFFERS_FROM_TEMP_BUFFER(VShROp)

// Vector unary ops
DEFINE_GET_EXTRA_BUFFERS_FROM_TEMP_BUFFER(VNotOp)
DEFINE_GET_EXTRA_BUFFERS_FROM_TEMP_BUFFER(VAbsOp)
DEFINE_GET_EXTRA_BUFFERS_FROM_TEMP_BUFFER(VLnOp)
DEFINE_GET_EXTRA_BUFFERS_FROM_TEMP_BUFFER(VReluOp)
DEFINE_GET_EXTRA_BUFFERS_FROM_TEMP_BUFFER(VExpOp)
DEFINE_GET_EXTRA_BUFFERS_FROM_TEMP_BUFFER(VRsqrtOp)
DEFINE_GET_EXTRA_BUFFERS_FROM_TEMP_BUFFER(VSqrtOp)
DEFINE_GET_EXTRA_BUFFERS_FROM_TEMP_BUFFER(VRecOp)

// Remaining ops
DEFINE_GET_EXTRA_BUFFERS_FROM_TEMP_BUFFER(VReduceOp)
DEFINE_GET_EXTRA_BUFFERS_FROM_TEMP_BUFFER(VSelOp)
DEFINE_GET_EXTRA_BUFFERS_FROM_TEMP_BUFFER(VBrcOp)
DEFINE_GET_EXTRA_BUFFERS_FROM_TEMP_BUFFER(VCastOp)
DEFINE_GET_EXTRA_BUFFERS_FROM_TEMP_BUFFER(VXorOp)
DEFINE_GET_EXTRA_BUFFERS_FROM_TEMP_BUFFER(VPowOp)
DEFINE_GET_EXTRA_BUFFERS_FROM_TEMP_BUFFER(VInterleaveOp)
DEFINE_GET_EXTRA_BUFFERS_FROM_TEMP_BUFFER(VMulextendedOp)
DEFINE_GET_EXTRA_BUFFERS_FROM_TEMP_BUFFER(VGatherOp)
DEFINE_GET_EXTRA_BUFFERS_FROM_TEMP_BUFFER(VTransposeOp)
DEFINE_GET_EXTRA_BUFFERS_FROM_TEMP_BUFFER(VSortOp)

#undef DEFINE_GET_EXTRA_BUFFERS_FROM_TEMP_BUFFER

//===----------------------------------------------------------------------===//
// Macros to help generate `getExtraBufferSize`
//===----------------------------------------------------------------------===//

#define ENABLE_DEFAULT_OP_GET_EXTRA_BUFFER_SIZE(OP_NAME)                       \
  std::optional<int64_t> OP_NAME::getExtraBufferSize() {                       \
    llvm_unreachable("Not implemented");                                       \
  }

// Vector Binary Op
ENABLE_DEFAULT_OP_GET_EXTRA_BUFFER_SIZE(VMulOp)
ENABLE_DEFAULT_OP_GET_EXTRA_BUFFER_SIZE(VAddOp)
ENABLE_DEFAULT_OP_GET_EXTRA_BUFFER_SIZE(VMaxOp)
ENABLE_DEFAULT_OP_GET_EXTRA_BUFFER_SIZE(VMinOp)
ENABLE_DEFAULT_OP_GET_EXTRA_BUFFER_SIZE(VAndOp)
ENABLE_DEFAULT_OP_GET_EXTRA_BUFFER_SIZE(VOrOp)
ENABLE_DEFAULT_OP_GET_EXTRA_BUFFER_SIZE(VSubOp)
ENABLE_DEFAULT_OP_GET_EXTRA_BUFFER_SIZE(VDivOp)
ENABLE_DEFAULT_OP_GET_EXTRA_BUFFER_SIZE(VShLOp)
ENABLE_DEFAULT_OP_GET_EXTRA_BUFFER_SIZE(VShROp)
// Vector Unary Op
ENABLE_DEFAULT_OP_GET_EXTRA_BUFFER_SIZE(VNotOp)
ENABLE_DEFAULT_OP_GET_EXTRA_BUFFER_SIZE(VAbsOp)
ENABLE_DEFAULT_OP_GET_EXTRA_BUFFER_SIZE(VLnOp)
ENABLE_DEFAULT_OP_GET_EXTRA_BUFFER_SIZE(VReluOp)
ENABLE_DEFAULT_OP_GET_EXTRA_BUFFER_SIZE(VExpOp)
ENABLE_DEFAULT_OP_GET_EXTRA_BUFFER_SIZE(VRsqrtOp)
ENABLE_DEFAULT_OP_GET_EXTRA_BUFFER_SIZE(VSqrtOp)
ENABLE_DEFAULT_OP_GET_EXTRA_BUFFER_SIZE(VRecOp)

ENABLE_DEFAULT_OP_GET_EXTRA_BUFFER_SIZE(VSelOp)
ENABLE_DEFAULT_OP_GET_EXTRA_BUFFER_SIZE(VSortOp)
ENABLE_DEFAULT_OP_GET_EXTRA_BUFFER_SIZE(VXorOp)
#undef ENABLE_DEFAULT_OP_GET_EXTRA_BUFFER_SIZE
